################################################################################
#       Copyright (C) 2010 Michael Yurko <myurko@gmail.com>
#
#This file is part of pyQAP.
#
#pyQAP is free software: you can redistribute it and/or modify
#it under the terms of the GNU General Public License as published by
#the Free Software Foundation, either version 2 of the License, or
#(at your option) any later version.
#
#pyQAP is distributed in the hope that it will be useful,
#but WITHOUT ANY WARRANTY; without even the implied warranty of
#MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#GNU General Public License for more details.
#
#You should have received a copy of the GNU General Public License
#along with pyQAP.  If not, see <http://www.gnu.org/licenses/>.
################################################################################

'''
contains Linear Assignment Problem Solvers.
'''

#imports
import numpy as np
cimport numpy as np

# Selected NumPy C-API entry points used directly by this module.
# PyArray_EMPTY is declared without a return type, so Cython treats its result
# as a Python object (the new ndarray).
cdef extern from "numpy/arrayobject.h":
    PyArray_EMPTY(int ndims, np.npy_intp* dims, int type, bint fortran)
    cdef void import_array()
# Initialise the NumPy C-API once at module import, before any of the C-level
# calls above are made.
import_array()

#types
# C-level element types used when the solver reads its arrays through raw
# pointers. The numpy dtypes below MUST have the same item size as these.
ctypedef Py_ssize_t np_int_t
ctypedef Py_ssize_t int_t
ctypedef np.float64_t float_t
ctypedef Py_ssize_t bool_t
# np.intp is numpy's pointer-sized signed integer, the same width as
# Py_ssize_t, so arrays allocated with it can be read via np_int_t*/bool_t*.
# (The former np.int meant a C long, which is 32-bit on 64-bit Windows, and it
# no longer exists in current NumPy releases.)
int_tp = np.intp
bool_tp = np.intp

#constants
# Sentinel start value for the minimum searches in the solver steps.
cdef float_t f_inf = 1e308
# Float tolerance; not referenced by the code in this module at present.
cdef float_t epsilon = 1e-6

########################################################################
#This currently uses an ugly hack due to a limitation in Cython's buffer
#handling: typed buffers can only be declared as function parameters or
#function-local variables, not as class attributes, so this code instead
#accesses the arrays' raw data pointers directly. The arrays passed *MUST* be
#C-contiguous and have the dtype matching the pointer type they are read as.
########################################################################

cdef class Hungarian:
    '''
    Reusable solver for linear assignment problems (LAPs) using the Hungarian
    (Munkres) algorithm. Contains the memory for all laps that are needed to be
    solved for the GLB: scratch arrays are allocated once for problems of up to
    max_n x max_n and reused by every call to get_cost.

    Entries of the marked matrix are 0 (plain), 1 (starred zero) or
    2 (primed zero).
    '''

    #class type declarations
    cdef public int_t max_n, n, zero_col,  zero_row
    cdef np.ndarray covered_rows, covered_cols, c, marked, path0, path1
    #pointer declarations (raw views of the arrays above)
    cdef bool_t *covered_rows_p,  *covered_cols_p
    cdef np_int_t *marked_p, *path0_p, *path1_p
    cdef float_t *c_p
    #row strides, in elements, of c and marked
    cdef int_t c_nd, marked_nd

    def __init__(self, max_n):
        '''
        Create a new Hungarian instance. max_n is the maximum size of lap that
        memory will be allocated for.
        '''
        #arrays
        self.covered_rows = np.zeros((max_n), dtype=bool_tp)
        self.covered_cols = np.zeros((max_n), dtype=bool_tp)
        #matrices
        self.c = None
        self.marked = np.zeros((max_n, max_n), dtype=int_tp)
        # an augmenting path alternates primes and stars, so 2*max_n entries
        # are enough
        self.path0 = np.zeros((max_n*2), dtype=int_tp)
        self.path1 = np.zeros((max_n*2), dtype=int_tp)
        #integers
        self.max_n = max_n
        self.n = 0
        self.zero_row = 0
        self.zero_col = 0
        #pointer declarations
        self.covered_cols_p = <bool_t*>self.covered_cols.data
        self.covered_rows_p = <bool_t*>self.covered_rows.data
        self.marked_p = <np_int_t*>self.marked.data
        self.path0_p = <np_int_t*>self.path0.data
        self.path1_p = <np_int_t*>self.path1.data
        #size declarations
        self.marked_nd = self.marked.shape[0]

    cdef int _check_matrix(self, object a, int_t n, name) except -1:
        '''
        Raise ValueError unless a is a 2-D, C-contiguous float64 array with at
        least n rows and n columns. get_cost reads its matrices through raw
        float_t pointers, so anything else would access the wrong memory.
        '''
        if a is None:
            raise ValueError('%s must not be None' % name)
        if a.ndim != 2:
            raise ValueError('%s must be 2-D, got %d dimension(s)'
                             % (name, a.ndim))
        if a.dtype != np.float64:
            raise ValueError('%s must have dtype float64, got %s'
                             % (name, a.dtype))
        if not a.flags['C_CONTIGUOUS']:
            raise ValueError('%s must be C-contiguous' % name)
        if a.shape[0] < n or a.shape[1] < n:
            raise ValueError('%s must be at least %d x %d, got %r'
                             % (name, n, n, a.shape))
        return 0

    cpdef double get_cost(self, np.ndarray cost_matrix, np.ndarray pristine,
                          n) except? -1:
        '''
        Returns the optimal LAP cost with the given cost matrix of size n. Note:
        cost_matrix will be changed. Pass pristine as a copy of the matrix that
        won't be changed. It will be used to calculate the cost after the
        optimal assignment has been found.

        Both matrices must be 2-D, C-contiguous float64 arrays with at least
        n rows and n columns; only their top-left n x n blocks are used.
        Raises ValueError if n is outside [0, max_n] or a matrix does not
        satisfy these requirements.
        '''
        cdef int_t i, j, size
        cdef int step = 1
        cdef double cost = 0.0
        cdef float_t *pristine_p
        cdef int_t pristine_nd

        # validate before touching any raw pointer
        if n < 0 or n > self.max_n:
            raise ValueError('n must be between 0 and %d, got %r'
                             % (self.max_n, n))
        size = n
        self._check_matrix(cost_matrix, size, 'cost_matrix')
        self._check_matrix(pristine, size, 'pristine')

        self.n = size
        self.c = cost_matrix
        #get pointer and row stride of c (C-contiguous: stride = column count)
        self.c_p = <float_t*>self.c.data
        self.c_nd = self.c.shape[1]
        self.zero_row = 0
        self.zero_col = 0

        #main loop: each step returns the number of the next step, -1 = done
        while True:
            #determine correct step to call
            if step == 1:
                step = self.step1()
            elif step == 2:
                step = self.step2()
            elif step == 3:
                step = self.step3()
            elif step == 4:
                step = self.step4()
            elif step == 5:
                step = self.step5()
            elif step == 6:
                step = self.step6()
            else:
                #if done then break from loop
                break

        # the starred zeros form the optimal assignment; sum their original
        # costs
        pristine_p = <float_t*>pristine.data
        pristine_nd = pristine.shape[1]
        for i in range(self.n):
            for j in range(self.n):
                if self.marked_p[i*self.marked_nd + j] == 1:
                    cost += pristine_p[i*pristine_nd + j]

        self.full_cleanup()

        return cost

    cdef int step1(self):
        '''
        Reduce the matrix by finding the minimum in each row and then subtract
        it from the whole row. After, go to step 2.
        '''
        cdef int_t i, j, scratch
        cdef float_t min_cost
        for i in range(self.n):
            scratch = i*self.c_nd
            #get min
            min_cost = f_inf
            for j in range(self.n):
                if self.c_p[scratch + j] < min_cost:
                    min_cost = self.c_p[scratch + j]
            #subtract min
            for j in range(self.n):
                self.c_p[scratch + j] -= min_cost

        return 2

    cdef int step2(self):
        """
        Star a zero in each row whose row and column hold no star yet. The
        covers are only used as bookkeeping here and are cleared again before
        going to step 3.
        """
        cdef int_t i, j, scratch
        for i in range(self.n):
            scratch = i*self.c_nd
            for j in range(self.n):
                if self.c_p[scratch + j] == 0 and\
                not self.covered_cols_p[j] and\
                not self.covered_rows_p[i]:
                    self.marked_p[i*self.marked_nd + j] = 1
                    self.covered_cols_p[j] = True
                    self.covered_rows_p[i] = True

        #some cleanup
        for i in range(self.n):
            self.covered_rows_p[i] = False
            self.covered_cols_p[i] = False

        return 3

    cdef int step3(self):
        """
        Cover every column that contains a starred zero. If n columns are
        covered the assignment is complete and -1 is returned; otherwise go to
        step 4.
        """
        cdef int_t i, j, scratch, c=0
        for i in range(self.n):
            scratch = i*self.marked_nd
            for j in range(self.n):
                if self.marked_p[scratch + j] == 1:
                    self.covered_cols_p[j] = True
                    c += 1

        if c >= self.n:
            #done with algorithm
            return -1
        else:
            #continue to step 4
            return 4

    cdef int step4(self):
        """
        Find an uncovered zero and prime it. If its row contains a starred
        zero, cover that row, uncover the star's column and repeat; otherwise
        record the primed zero in zero_row/zero_col and go to step 5. If no
        uncovered zero remains, go to step 6.
        """
        cdef int_t i, j, scratch, done, row, col, col_starred
        done = False
        col_starred = -1
        row = -1
        col = -1
        while True:
            #get zero
            row = -1
            col = -1
            i = 0
            done = False

            while not done:
                j = 0
                while True:
                    scratch = i*self.c_nd
                    if (self.c_p[scratch + j] == 0) and \
                       (not self.covered_rows_p[i]) and \
                       (not self.covered_cols_p[j]):
                        row = i
                        col = j
                        done = True
                    j += 1
                    if j >= self.n:
                        break
                i += 1
                if i >= self.n:
                    done = True

            if row < 0:     #if not found skip to step 6
                return 6
            else:
                # prime the zero
                self.marked_p[row*self.marked_nd + col] = 2
                #get star
                col_starred = -1
                scratch = row*self.marked_nd
                for j in range(self.n):
                    if self.marked_p[scratch + j] == 1:
                        col_starred = j
                        break
                #change covered status
                if col_starred >= 0:
                    col = col_starred
                    self.covered_cols_p[col] = False
                    self.covered_rows_p[row] = True
                else:
                    self.zero_col = col
                    self.zero_row = row
                    return 5

    cdef int step5(self):
        """
        Create a path that alternates between the two types of zeros, starting
        at the primed zero (zero_row, zero_col): from a prime go to the star in
        its column, from a star go to the prime in its row. Then unstar the
        stars and star the primes on the path, clear all covers, erase all
        primes and go to step 3.
        """
        cdef int_t i, j, scratch, c, done, row, col
        c = 0
        self.path0_p[c] = self.zero_row
        self.path1_p[c] = self.zero_col
        done = False

        while not done:
            #find the star in the current column
            row = -1
            scratch = self.path1_p[c]
            for i in range(self.n):
                if self.marked_p[i*self.marked_nd + scratch] == 1:
                    row = i
                    break
            if row >= 0:
                c += 1
                self.path0_p[c] = row
                self.path1_p[c] = self.path1_p[c-1]
            else:
                done = True

            if not done:
                #get next prime
                col = -1
                scratch = self.path0_p[c]*self.marked_nd
                for j in range(self.n):
                    if self.marked_p[scratch + j] == 2:
                        col = j
                        break
                c += 1
                self.path0_p[c] = self.path0_p[c-1]
                self.path1_p[c] = col

        #put path into marked
        for i in range(c+1):
            scratch = self.path0_p[i]*self.marked_nd + self.path1_p[i]
            if self.marked_p[scratch] == 1:
                self.marked_p[scratch] = 0
            else:
                self.marked_p[scratch] = 1

        #cleanup
        for i in range(self.n):
            self.covered_rows_p[i] = False
            self.covered_cols_p[i] = False
        for i in range(self.n):
            scratch = i*self.marked_nd
            for j in range(self.n):
                if self.marked_p[scratch + j] == 2:
                    self.marked_p[scratch + j] = 0

        return 3

    cdef int step6(self):
        """
        Find the smallest uncovered value, add it to every element of each
        covered row and subtract it from every element of each uncovered
        column. Then go to step 4.
        """
        cdef int_t i, j, scratch
        #get the minimum
        cdef float_t min_num = f_inf
        for i in range(self.n):
            scratch = i*self.c_nd
            for j in range(self.n):
                if not self.covered_rows_p[i] and not self.covered_cols_p[j]:
                    if min_num > self.c_p[scratch + j]:
                        min_num = self.c_p[scratch + j]

        #do adds and subtracts
        for i in range(self.n):
            scratch = i*self.c_nd
            for j in range(self.n):
                if self.covered_rows_p[i]:
                    self.c_p[scratch + j] += min_num
                if not self.covered_cols_p[j]:
                    self.c_p[scratch + j] -= min_num

        #go to step 4
        return 4

    cdef full_cleanup(self):
        '''
        Reset all variables before next usage and drop the reference to the
        caller's cost matrix.
        '''
        cdef int_t i, j, scratch
        #reset zero vars
        self.zero_row = 0
        self.zero_col = 0
        #reset marked
        for i in range(self.n):
            scratch = i*self.marked_nd
            for j in range(self.n):
                self.marked_p[scratch + j] = 0
        #reset covers
        for i in range(self.n):
            self.covered_cols_p[i] = False
            self.covered_rows_p[i] = False
        #reset path
        for i in range(self.max_n*2):
            self.path0_p[i] = 0
            self.path1_p[i] = 0
        #release the cost matrix; get_cost sets both again before use
        self.c = None
        self.c_p = NULL
            
h = None  # module-wide solver instance, (re)built by create_hungarian()


def create_hungarian(n):
    '''
    Build a fresh module-wide Hungarian solver sized for LAPs of up to n x n
    and store it in the global h, replacing any previous solver.

    :param n: the maximum allowable lap size
    :type n: int

    Examples:

    Exercised through hungarian.
    '''
    global h
    solver = Hungarian(n)
    h = solver

def hungarian(np.ndarray cost_matrix, int_t n):
    '''
    Solves the given linear assignment problem by using the hungarian algorithm.
    Returns only the cost.

    Only the top-left n x n block of cost_matrix is used, and cost_matrix
    itself is never modified: the solver works on a private float64 copy.

    :param cost_matrix: the LAP cost matrix
    :type cost_matrix: a 2-D ndarray with at least n rows and n columns
    :param n: the size of the given LAP
    :type n: int
    :returns: the cost of the optimal solution to the given LAP
    :rtype: float
    :raises ValueError: if n is negative or cost_matrix is not a 2-D array
        of at least n x n

    Examples:

    >>> import numpy as np
    >>> from pyQAP.lap import hungarian
    >>> np.random.seed(123)
    >>> c = np.random.random((5,5))
    >>> hungarian(c,5)
    1.495539557372743
    >>> c = np.random.random((5,5))
    >>> hungarian(c,5)
    1.4573083610474724
    >>> c = np.random.random((10,10))
    >>> hungarian(c,10)
    1.4049864705881505
    >>> c = np.random.random((5,5))*10
    >>> hungarian(c,5)
    11.36045737416805
    >>> c = np.random.random((5,5))*100
    >>> hungarian(c,5)
    83.968980256759025
    >>> c = np.random.random((10,10))*100
    >>> hungarian(c,10)
    95.260588082819268
    '''
    if n < 0:
        raise ValueError('n must be non-negative, got %d' % n)
    # the checks below must come before any indexing of cost_matrix
    if cost_matrix is None or cost_matrix.ndim != 2 or \
       cost_matrix.shape[0] < n or cost_matrix.shape[1] < n:
        raise ValueError('cost_matrix must be a 2-D array of at least %d x %d'
                         % (n, n))
    if h is None or h.max_n < n:
        create_hungarian(n)
    # Extract the n x n subproblem as a C-contiguous float64 array. Slicing
    # respects the real row stride, so larger, non-contiguous or non-float64
    # matrices still yield the correct block; when cost_matrix already is an
    # n x n C-contiguous float64 array no data is copied here.
    cdef np.ndarray pristine = np.ascontiguousarray(cost_matrix[:n, :n],
                                                    dtype=np.float64)
    # get_cost reduces its first argument in place, so hand it a copy and keep
    # the pristine values for summing the optimal assignment's cost.
    cdef np.ndarray work = pristine.copy()
    return h.get_cost(work, pristine, n)
